#!/usr/bin/env python
import rospy

import gym
import keras
import numpy as np
import random

from gym import wrappers
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam

from collections import deque

from openai_ros.msg import RLExperimentInfo

# Number of discrete actions; used to shape Q-value / target arrays.
ACTIONS_DIM = 2
# Length of one observation vector; used to shape network inputs.
OBSERVATIONS_DIM = 4
# Upper bound on the number of steps main() runs inside a single episode.
MAX_ITERATIONS = 10**6
# Learning rate handed to the Adam optimizer in get_model().
LEARNING_RATE = 0.001

# NOTE(review): not referenced anywhere in the visible code -- confirm before removing.
NUM_EPOCHS = 50

# Discount factor applied to the bootstrapped next-state value in update_action().
GAMMA = 0.99
# Capacity passed to ReplayBuffer in main().
REPLAY_MEMORY_SIZE = 1000
# Number of episodes main() runs.
NUM_EPISODES = 10000
# Initial value of main()'s steps_until_reset counter; the target-network sync
# that would consume that counter is currently commented out in main().
TARGET_UPDATE_FREQ = 100
# Number of transitions sampled from the replay buffer per training update.
MINIBATCH_SIZE = 32

# Multiplicative decay applied to the random-action probability on every step
# (main() floors the result at 0.1).
RANDOM_ACTION_DECAY = 0.99
# Starting probability of picking a random action.
INITIAL_RANDOM_ACTION = 1

class ReplayBuffer(object):
  """Fixed-capacity FIFO store of (observation, action, reward, observation2) tuples.

  Once `max_size` transitions are held, adding another one discards the
  oldest, so the buffer never holds more than `max_size` entries.
  """

  def __init__(self, max_size):
    """Create an empty buffer that keeps at most `max_size` transitions."""
    self.max_size = max_size
    # A bounded deque evicts from the left automatically when full, which
    # avoids the off-by-one of checking the length before appending.
    self.transitions = deque(maxlen=max_size)

  def add(self, observation, action, reward, observation2):
    """Append one transition, evicting the oldest one if the buffer is full."""
    self.transitions.append((observation, action, reward, observation2))

  def sample(self, count):
    """Return `count` distinct stored transitions, chosen uniformly at random.

    Raises ValueError (from random.sample) if `count` exceeds size().
    """
    return random.sample(self.transitions, count)

  def size(self):
    """Return the number of transitions currently stored."""
    return len(self.transitions)

def get_q(model, observation):
  """Return `model.predict` output for `observation`.

  `observation` is first reshaped into rows of OBSERVATIONS_DIM values, so a
  single observation becomes a one-row batch.
  """
  batch = np.reshape(observation, (-1, OBSERVATIONS_DIM))
  q_values = model.predict(batch)
  return q_values

def train(model, observations, targets):
  """Run a single fit pass of `model` on a batch of observations and targets.

  :param model: object exposing a Keras-style `fit` method
  :param observations: values reshaped into rows of OBSERVATIONS_DIM entries
  :param targets: values reshaped into rows of ACTIONS_DIM entries
  :return: None
  """
  inputs = np.reshape(observations, (-1, OBSERVATIONS_DIM))
  labels = np.reshape(targets, (-1, ACTIONS_DIM))

  # NOTE(review): `nb_epoch` looks like the Keras 1.x spelling of `epochs`;
  # confirm the installed Keras version before renaming it.
  model.fit(inputs, labels, nb_epoch=1, verbose=0)

def predict(model, observation):
  """Return the model's predictions for `observation`.

  Performs exactly the same reshape-and-predict as get_q(), so it simply
  delegates to it.
  """
  return get_q(model, observation)

def get_model():
  """Build and compile the Q-network.

  Architecture: OBSERVATIONS_DIM inputs -> Dense(16, relu) -> Dense(16, relu)
  -> Dense(ACTIONS_DIM, linear), compiled with Adam(LEARNING_RATE) and MSE loss.

  :return: the compiled Sequential model
  """
  model = Sequential()
  # Only the first layer of a Sequential model needs the input shape; later
  # layers infer theirs from the preceding layer's output.
  model.add(Dense(16, input_shape=(OBSERVATIONS_DIM,), activation='relu'))
  model.add(Dense(16, activation='relu'))
  # One linear output per action, kept in sync with ACTIONS_DIM so it matches
  # the target arrays reshaped to ACTIONS_DIM columns elsewhere in this file.
  model.add(Dense(ACTIONS_DIM, activation='linear'))

  model.compile(
    optimizer=Adam(lr=LEARNING_RATE),
    loss='mse',
    metrics=[],
  )

  return model

def update_action(action_model, target_model, sample_transitions):
  """Run one Q-learning update of `action_model` on a batch of transitions.

  For each transition (observation, action, reward, next_observation) the
  target for the taken action is `reward`, plus
  GAMMA * max(target_model prediction for next_observation) when
  next_observation is not None. Targets for the other actions are
  `action_model`'s own current predictions.

  :param action_model: model being trained
  :param target_model: model used to evaluate next observations
  :param sample_transitions: iterable of 4-tuples as described above; it is
    not modified
  :return: None
  """
  # Shuffle a private copy: shuffling in place would reorder the caller's list.
  transitions = list(sample_transitions)
  if not transitions:
    return
  random.shuffle(transitions)

  batch_observations = [transition[0] for transition in transitions]

  # One forward pass for the whole batch instead of one predict() per sample.
  # np.array() makes a writable copy to edit in place.
  batch_targets = np.array(get_q(action_model, batch_observations))

  for row, (_, action, reward, _) in enumerate(transitions):
    batch_targets[row, action] = reward

  # Rows whose transition has a next observation get the discounted bootstrap.
  bootstrap_rows = [
    row for row, transition in enumerate(transitions)
    if transition[3] is not None
  ]
  if bootstrap_rows:
    next_observations = [transitions[row][3] for row in bootstrap_rows]
    taken_actions = [transitions[row][1] for row in bootstrap_rows]
    next_q = predict(target_model, next_observations)
    # bootstrap_rows holds unique indices, so the fancy-indexed += is safe.
    batch_targets[bootstrap_rows, taken_actions] += GAMMA * np.max(next_q, axis=1)

  train(action_model, batch_observations, batch_targets)


class PublishRewardClass(object):
    """Accumulates the reward of the current episode and publishes it.

    Messages of type RLExperimentInfo are sent on the '/openai/reward' topic.
    """

    def __init__(self):
        # Episode bookkeeping
        self.episode_num = 0
        self.cumulated_episode_reward = 0
        # Publisher for the reward topic
        self.reward_pub = rospy.Publisher(
            '/openai/reward', RLExperimentInfo, queue_size=1)

    def _publish_reward_topic(self, reward, episode_number=1):
        """
        Build an RLExperimentInfo message from the arguments and publish it.
        :param reward: value stored in the message's episode_reward field
        :param episode_number: value stored in the message's episode_number field
        :return:
        """
        msg = RLExperimentInfo()
        msg.episode_number, msg.episode_reward = episode_number, reward
        self.reward_pub.publish(msg)

    def _update_episode(self):
        """
        Publish the reward accumulated so far together with the current
        episode number, then advance the episode counter and reset the
        accumulator to zero.
        :return:
        """
        finished_reward = self.cumulated_episode_reward
        finished_episode = self.episode_num
        self._publish_reward_topic(finished_reward, finished_episode)
        self.episode_num = finished_episode + 1
        self.cumulated_episode_reward = 0

    def update_cumulated_reward(self, reward):
        """
        Add `reward` to the running total of the current episode.
        """
        self.cumulated_episode_reward += reward

def main():
  """Train the Q-network on CartPole-v0 and publish each episode's reward.

  Every step picks an epsilon-greedy action, stores the transition in the
  replay buffer and, once the buffer holds at least MINIBATCH_SIZE
  transitions, trains on a sampled minibatch. When an episode ends, the
  accumulated reward is published through PublishRewardClass. The loop is
  paced with a 100 Hz rospy.Rate, so rospy.init_node() must have been called
  beforehand.
  """
  steps_until_reset = TARGET_UPDATE_FREQ
  random_action_probability = INITIAL_RANDOM_ACTION

  # Initialize replay memory D to capacity N
  replay = ReplayBuffer(REPLAY_MEMORY_SIZE)

  # Initialize action-value model with random weights
  action_model = get_model()

  # NOTE: the separate target network is disabled; action_model is passed to
  # update_action() as its own target. Re-enable the commented lines here and
  # in the loop below to restore periodic target syncing.
  #target_model = get_model()
  #target_model.set_weights(action_model.get_weights())

  pub_reward_obj = PublishRewardClass()

  rate = rospy.Rate(100)

  env = gym.make('CartPole-v0')
  env = wrappers.Monitor(env, '/tmp/cartpole-experiment-1', force=True)

  try:
    for episode in range(NUM_EPISODES):
      observation = env.reset()

      for iteration in range(MAX_ITERATIONS):
        # Decay exploration on every step, never dropping below 10%.
        random_action_probability *= RANDOM_ACTION_DECAY
        random_action_probability = max(random_action_probability, 0.1)
        old_observation = observation

        # Rendered on every step of every episode.
        env.render()

        if np.random.random() < random_action_probability:
          action = np.random.choice(range(ACTIONS_DIM))
        else:
          q_values = get_q(action_model, observation)
          action = np.argmax(q_values)

        observation, reward, done, info = env.step(action)

        # Cumulate reward (the environment's reward, before any penalty)
        pub_reward_obj.update_cumulated_reward(reward)

        if done:
          # Single parenthesised argument: prints the same on Python 2 and 3.
          print('Episode {}, iterations: {}'.format(episode, iteration))

          # The final transition is stored with a fixed penalty and no
          # next observation, so update_action() does not bootstrap from it.
          reward = -200
          replay.add(old_observation, action, reward, None)

          pub_reward_obj._update_episode()

          break

        replay.add(old_observation, action, reward, observation)

        if replay.size() >= MINIBATCH_SIZE:
          sample_transitions = replay.sample(MINIBATCH_SIZE)
          update_action(action_model, action_model, sample_transitions)
          # Only consumed by the disabled target-network sync below.
          steps_until_reset -= 1

        # if steps_until_reset == 0:
        #   target_model.set_weights(action_model.get_weights())
        #   steps_until_reset = TARGET_UPDATE_FREQ

        rate.sleep()
  finally:
    # Close the environment (and its Monitor wrapper) even when training is
    # interrupted, e.g. by an exception raised from rate.sleep() on shutdown.
    env.close()

# Script entry point: register this process as a ROS node (main() creates a
# rospy.Publisher and a rospy.Rate), then run the training loop.
if __name__ == "__main__":
    rospy.init_node('cartpole_mbalunovic_algorithm', anonymous=True, log_level=rospy.INFO)
    main()